{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import cv2\n",
    "import requests\n",
    "import json\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from models.config import TABLE_PATH, TEST_PATH, OUTPUT_SHAPE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 配置信息"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Character table: one stripped entry per line of the GBK-encoded file at\n",
    "# TABLE_PATH, followed by two single-space padding entries.\n",
    "with open(TABLE_PATH, mode='r', encoding='gbk') as table_file:\n",
    "    inv_table = [entry.strip() for entry in table_file]\n",
    "inv_table.extend([' ', ' '])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test dataset: path of every entry in TEST_PATH, in sorted filename order.\n",
    "# os.path.join is used instead of string concatenation so the paths are valid\n",
    "# whether or not TEST_PATH ends with a path separator.\n",
    "test_all_image_paths = [\n",
    "    os.path.join(TEST_PATH, file_name)\n",
    "    for file_name in sorted(os.listdir(TEST_PATH))\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 调用 TensorFlow Serving 方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def text_ocr(image_list,\n",
    "             url='http://192.168.46.230:8501/v1/models/crnn:predict',\n",
    "             timeout=50):\n",
    "    '''\n",
    "    Recognize the text in a batch of rectangular text-box images.\n",
    "\n",
    "    Every image is resized and reshaped to OUTPUT_SHAPE, the whole batch is\n",
    "    posted as JSON to the model's REST predict endpoint, and each returned\n",
    "    sequence of per-step scores is decoded greedily with inv_table.\n",
    "\n",
    "    Args:\n",
    "        image_list: iterable of image arrays accepted by cv2.resize.\n",
    "        url: predict endpoint of the served model.\n",
    "        timeout: request timeout in seconds, passed to requests.post.\n",
    "\n",
    "    Returns:\n",
    "        A list with one decoded string per returned prediction. The list is\n",
    "        empty when image_list is empty or when the JSON response has no\n",
    "        'predictions' entry.\n",
    "    '''\n",
    "    result_list = []\n",
    "    input_image = []\n",
    "    for img in image_list:\n",
    "        # cv2.resize expects dsize as (width, height). The previous call also\n",
    "        # passed a positional 3, which fills the dst slot and does not select\n",
    "        # an interpolation mode, so it is dropped (default interpolation kept).\n",
    "        # NOTE(review): if INTER_AREA (== 3) was intended, pass\n",
    "        # interpolation=cv2.INTER_AREA explicitly - confirm against training.\n",
    "        image = cv2.resize(img, (OUTPUT_SHAPE[1], OUTPUT_SHAPE[0]))\n",
    "        input_image.append(image.reshape(OUTPUT_SHAPE).tolist())\n",
    "    if not input_image:\n",
    "        # Nothing to recognize: skip the network round trip.\n",
    "        return result_list\n",
    "\n",
    "    headers = {'content-type': 'application/json'}\n",
    "    data = json.dumps({\n",
    "        'signature_name': 'serving_default',\n",
    "        'instances': input_image})\n",
    "    predictions = requests.post(url, data=data, headers=headers, timeout=timeout).json()\n",
    "\n",
    "    # The class index one past the end of inv_table is skipped as the blank.\n",
    "    blank_index = len(inv_table)\n",
    "    if 'predictions' in predictions:\n",
    "        for sentence in predictions['predictions']:\n",
    "            decoded_chars = []\n",
    "            previous_index = None\n",
    "            for step_scores in sentence:\n",
    "                char_index = int(np.argmax(step_scores))\n",
    "                # Greedy decoding merges consecutive repeats of the same class\n",
    "                # before blanks are removed; without this, a character held\n",
    "                # over several steps is emitted more than once.\n",
    "                # NOTE(review): assumes the served CRNN is CTC-trained - confirm.\n",
    "                if char_index != previous_index and char_index != blank_index:\n",
    "                    decoded_chars.append(inv_table[char_index])\n",
    "                previous_index = char_index\n",
    "            result_list.append(''.join(decoded_chars))\n",
    "    return result_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备图片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_resized_image(path):\n",
    "    '''Read the JPEG at path, decode it to 3 channels, resize to 32x280 and return a NumPy array.'''\n",
    "    raw_bytes = tf.io.read_file(path)\n",
    "    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)\n",
    "    return tf.image.resize(decoded, [32, 280]).numpy()\n",
    "\n",
    "\n",
    "# Two sample images taken from the sorted test set (entries 4 and 0).\n",
    "image1 = _load_resized_image(test_all_image_paths[4])\n",
    "image2 = _load_resized_image(test_all_image_paths[0])\n",
    "image_list = [image1, image2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 调用接口"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "识别时间： 0.49924182891845703\n",
      "识别结果： ['卡卡塔尔航空公', '将将驴的头强行按到水中']\n"
     ]
    }
   ],
   "source": [
    "# perf_counter is a monotonic, high-resolution clock meant for measuring\n",
    "# intervals; time.time() can jump if the system clock is adjusted.\n",
    "start = time.perf_counter()\n",
    "result = text_ocr(image_list)\n",
    "elapsed = time.perf_counter() - start\n",
    "print('识别时间：', elapsed)\n",
    "print('识别结果：', result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
